{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Seq2Seq\n",
    "\n",
    "本篇代码将实现一个基础版的Seq2Seq，输入一个单词（字母序列），模型将返回一个对字母排序后的“单词”。\n",
    "\n",
    "基础Seq2Seq主要包含三部分：\n",
    "\n",
    "- Encoder\n",
    "- 隐层状态向量（连接Encoder和Decoder）\n",
    "- Decoder\n",
    "\n",
    "# 任务：\n",
    "按字母顺序排序：hello  -->  ehllo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 查看TensorFlow版本\n",
    "\n",
    "http://www.lfd.uci.edu/~gohlke/pythonlibs/#tensorflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "TensorFlow Version: 1.2.0\n"
     ]
    }
   ],
   "source": [
    "from distutils.version import LooseVersion\n",
    "import tensorflow as tf\n",
    "from tensorflow.python.layers.core import Dense\n",
    "\n",
    "\n",
    "# Stop early unless the installed TensorFlow is at least version 1.1.\n",
    "installed_tf_version = LooseVersion(tf.__version__)\n",
    "minimum_tf_version = LooseVersion('1.1')\n",
    "assert installed_tf_version >= minimum_tf_version, 'Please use TensorFlow version 1.1 or newer'\n",
    "print('TensorFlow Version: {}'.format(tf.__version__))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据加载"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import time\n",
    "import tensorflow as tf\n",
    "\n",
    "# Load both corpora as single strings; later cells split them on newlines.\n",
    "source_path = 'data/letters_source.txt'\n",
    "target_path = 'data/letters_target.txt'\n",
    "\n",
    "with open(source_path, mode='r', encoding='utf-8') as source_file:\n",
    "    source_data = source_file.read()\n",
    "\n",
    "with open(target_path, mode='r', encoding='utf-8') as target_file:\n",
    "    target_data = target_file.read()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['bsaqq',\n",
       " 'npy',\n",
       " 'lbwuj',\n",
       " 'bqv',\n",
       " 'kial',\n",
       " 'tddam',\n",
       " 'edxpjpg',\n",
       " 'nspv',\n",
       " 'huloz',\n",
       " 'kmclq']"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 数据预览\n",
    "source_data.split('\\n')[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['abqqs',\n",
       " 'npy',\n",
       " 'bjluw',\n",
       " 'bqv',\n",
       " 'aikl',\n",
       " 'addmt',\n",
       " 'degjppx',\n",
       " 'npsv',\n",
       " 'hlouz',\n",
       " 'cklmq']"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_data.split('\\n')[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 数据预处理"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def extract_character_vocab(data):\n",
    "    '''\n",
    "    构造映射表\n",
    "    '''\n",
    "    special_words = ['<PAD>', '<UNK>', '<GO>',  '<EOS>']\n",
    "\n",
    "    set_words = list(set([character for line in data.split('\\n') for character in line]))\n",
    "    # 这里要把四个特殊字符添加进词典\n",
    "    int_to_vocab = {idx: word for idx, word in enumerate(special_words + set_words)}\n",
    "    vocab_to_int = {word: idx for idx, word in int_to_vocab.items()}\n",
    "\n",
    "    return int_to_vocab, vocab_to_int"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 构造映射表\n",
    "# Build separate lookup tables for the source and target corpora.\n",
    "source_int_to_letter, source_letter_to_int = extract_character_vocab(source_data)\n",
    "target_int_to_letter, target_letter_to_int = extract_character_vocab(target_data)\n",
    "\n",
    "source_unk_id = source_letter_to_int['<UNK>']\n",
    "target_unk_id = target_letter_to_int['<UNK>']\n",
    "target_eos_id = target_letter_to_int['<EOS>']\n",
    "\n",
    "# Encode every source line as a list of letter ids; unknown letters map to <UNK>.\n",
    "source_int = []\n",
    "for line in source_data.split('\\n'):\n",
    "    source_int.append([source_letter_to_int.get(letter, source_unk_id) for letter in line])\n",
    "\n",
    "# Encode every target line the same way and terminate it with <EOS>.\n",
    "target_int = []\n",
    "for line in target_data.split('\\n'):\n",
    "    encoded_line = [target_letter_to_int.get(letter, target_unk_id) for letter in line]\n",
    "    target_int.append(encoded_line + [target_eos_id])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "collapsed": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[17, 9, 12, 11, 11],\n",
       " [16, 29, 26],\n",
       " [13, 17, 15, 25, 8],\n",
       " [17, 11, 4],\n",
       " [18, 10, 12, 13],\n",
       " [23, 7, 7, 12, 24],\n",
       " [27, 7, 6, 29, 8, 29, 5],\n",
       " [16, 9, 29, 4],\n",
       " [28, 25, 13, 21, 20],\n",
       " [18, 24, 22, 13, 11]]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 查看一下转换结果\n",
    "source_int[:10]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[12, 17, 11, 11, 9, 3],\n",
       " [16, 29, 26, 3],\n",
       " [17, 8, 13, 25, 15, 3],\n",
       " [17, 11, 4, 3],\n",
       " [12, 10, 18, 13, 3],\n",
       " [12, 7, 7, 24, 23, 3],\n",
       " [7, 27, 5, 8, 29, 29, 6, 3],\n",
       " [16, 29, 9, 4, 3],\n",
       " [28, 13, 21, 25, 20, 3],\n",
       " [22, 18, 13, 24, 11, 3]]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "target_int[:10]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 构建模型"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 输入层"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_inputs():\n",
    "    '''\n",
    "    Create the placeholders that feed the model.\n",
    "\n",
    "    Returns:\n",
    "    - a 6-tuple (inputs, targets, learning_rate, target_sequence_length,\n",
    "      max_target_sequence_length, source_sequence_length).\n",
    "    '''\n",
    "    # Rank-2 int32 id matrices; both dimensions are left unspecified.\n",
    "    id_matrix_shape = [None, None]\n",
    "    inputs = tf.placeholder(tf.int32, id_matrix_shape, name='inputs')\n",
    "    targets = tf.placeholder(tf.int32, id_matrix_shape, name='targets')\n",
    "    learning_rate = tf.placeholder(tf.float32, name='learning_rate')\n",
    "\n",
    "    # The length vectors are fed at run time; the largest target length is\n",
    "    # computed inside the graph from the fed target lengths.\n",
    "    length_vector_shape = (None,)\n",
    "    target_sequence_length = tf.placeholder(tf.int32, length_vector_shape, name='target_sequence_length')\n",
    "    max_target_sequence_length = tf.reduce_max(target_sequence_length, name='max_target_len')\n",
    "    source_sequence_length = tf.placeholder(tf.int32, length_vector_shape, name='source_sequence_length')\n",
    "\n",
    "    return (inputs, targets, learning_rate,\n",
    "            target_sequence_length, max_target_sequence_length, source_sequence_length)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Encoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在Encoder端，我们需要进行两步，第一步要对我们的输入进行Embedding，再把Embedding以后的向量传给RNN进行处理。\n",
    "\n",
    "在Embedding中，我们使用[tf.contrib.layers.embed_sequence](https://www.tensorflow.org/api_docs/python/tf/contrib/layers/embed_sequence)，它会对每个batch执行embedding操作。\n",
    "* ## tf.contrib.layers.embed_sequence：##\n",
    "\n",
    "对序列数据执行embedding操作，输入[batch_size, sequence_length]的tensor，返回[batch_size, sequence_length, embed_dim]的tensor。\n",
    "\n",
    "features = [[1,2,3],[4,5,6]]\n",
    "\n",
    "  outputs = tf.contrib.layers.embed_sequence(features, vocab_size, embed_dim)\n",
    "  \n",
    "  如果embed_dim=4，输出结果为\n",
    "  \n",
    "  [  \n",
    "  \t[[0.1,0.2,0.3,0.1],[0.2,0.5,0.7,0.2],[0.1,0.6,0.1,0.2]],    \n",
    "  \t[[0.6,0.2,0.8,0.2],[0.5,0.6,0.9,0.2],[0.3,0.9,0.2,0.2]]    \n",
    "  ]\n",
    "  \n",
    "* ## tf.contrib.rnn.MultiRNNCell：##\n",
    "\n",
    "对RNN单元按序列堆叠。接受参数为一个由RNN cell组成的list。\n",
    "\n",
    "rnn_size代表一个rnn单元中隐层节点数量，num_layers代表堆叠的rnn cell个数\n",
    "\n",
    "* ## tf.nn.dynamic_rnn：##\n",
    "\n",
    "构建RNN，接受动态输入序列。返回RNN的输出以及最终状态的tensor。\n",
    "\n",
    "dynamic_rnn与rnn的区别在于，dynamic_rnn对于不同的batch，可以接收不同的sequence_length。\n",
    "\n",
    "例如，第一个batch是[batch_size,10]，第二个batch是[batch_size,20]。而rnn只能接收定长的sequence_length。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_encoder_layer(input_data, rnn_size, num_layers,\n",
    "                   source_sequence_length, source_vocab_size,\n",
    "                   encoding_embedding_size):\n",
    "    '''\n",
    "    Build the encoder: an embedding followed by a stack of LSTM cells.\n",
    "\n",
    "    Args:\n",
    "    - input_data: input id tensor\n",
    "    - rnn_size: number of hidden units in each LSTM cell\n",
    "    - num_layers: number of stacked LSTM cells\n",
    "    - source_sequence_length: sequence lengths passed to dynamic_rnn\n",
    "    - source_vocab_size: size of the source vocabulary\n",
    "    - encoding_embedding_size: size of the embedding vectors\n",
    "\n",
    "    Returns:\n",
    "    - (encoder_output, encoder_state) as returned by tf.nn.dynamic_rnn\n",
    "    '''\n",
    "    # Embed the input ids\n",
    "    embedded_inputs = tf.contrib.layers.embed_sequence(input_data, source_vocab_size, encoding_embedding_size)\n",
    "\n",
    "    # One LSTM cell per layer, all created with the same seeded initializer\n",
    "    stacked_cells = []\n",
    "    for _ in range(num_layers):\n",
    "        layer_initializer = tf.random_uniform_initializer(-0.1, 0.1, seed=2)\n",
    "        stacked_cells.append(tf.contrib.rnn.LSTMCell(rnn_size, initializer=layer_initializer))\n",
    "    encoder_cell = tf.contrib.rnn.MultiRNNCell(stacked_cells)\n",
    "\n",
    "    encoder_output, encoder_state = tf.nn.dynamic_rnn(encoder_cell,\n",
    "                                                      embedded_inputs,\n",
    "                                                      sequence_length=source_sequence_length,\n",
    "                                                      dtype=tf.float32)\n",
    "\n",
    "    return encoder_output, encoder_state"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Decoder"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 对target数据进行预处理\n",
    "\n",
    "<img src=\"figure/f1.png\" alt=\"FAO\" width=\"690\" >"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def process_decoder_input(data, vocab_to_int, batch_size):\n",
    "    '''\n",
    "    Build the decoder input: drop the last column of data and prepend a <GO> column.\n",
    "\n",
    "    Args:\n",
    "    - data: 2-D id tensor (it is sliced with two indices per axis)\n",
    "    - vocab_to_int: lookup table that contains the '<GO>' token\n",
    "    - batch_size: number of rows kept from data and number of <GO> rows created\n",
    "    '''\n",
    "    go_id = vocab_to_int['<GO>']\n",
    "\n",
    "    # Rows [0, batch_size), every column except the final one\n",
    "    slice_begin, slice_end, slice_step = [0, 0], [batch_size, -1], [1, 1]\n",
    "    without_last_column = tf.strided_slice(data, slice_begin, slice_end, slice_step)\n",
    "\n",
    "    go_column = tf.fill([batch_size, 1], go_id)\n",
    "    return tf.concat([go_column, without_last_column], 1)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 对数据进行embedding\n",
    "\n",
    "同样地，我们还需要对target数据进行embedding，使得它们能够传入Decoder中的RNN。\n",
    "\n",
    "* ## tf.contrib.seq2seq.TrainingHelper：##\n",
    "\n",
    "Decoder端用来训练的函数。\n",
    "\n",
    "这个函数不会把t-1阶段的输出作为t阶段的输入，而是把target中的真实值直接输入给RNN。\n",
    "\n",
    "主要参数是inputs和sequence_length。返回helper对象，可以作为BasicDecoder函数的参数。\n",
    "\n",
    "* ## tf.contrib.seq2seq.GreedyEmbeddingHelper：##\n",
    "\n",
    "它和TrainingHelper的区别在于它会把t-1下的输出进行embedding后再输入给RNN。\n",
    "\n",
    "\n",
    "### 下面的图中代表的是training过程。###\n",
    "\n",
    "在training过程中，我们并不会把每个阶段的预测输出作为下一阶段的输入，而是直接把target data中的真实值作为下一阶段的输入（即Teacher Forcing）。这样可以避免训练早期的预测错误在后续阶段不断累积，使训练更稳定、收敛更快。\n",
    "\n",
    "<img src=\"figure/f2.png\" alt=\"FAO\" width=\"690\" >"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def decoding_layer(target_letter_to_int, decoding_embedding_size, num_layers, rnn_size,\n",
    "                   target_sequence_length, max_target_sequence_length, encoder_state, decoder_input):\n",
    "    '''\n",
    "    Build the decoder: a training decoder fed with the ground-truth inputs and a\n",
    "    greedy inference decoder that reuses the same variables.\n",
    "\n",
    "    Args:\n",
    "    - target_letter_to_int: lookup table of the target vocabulary\n",
    "    - decoding_embedding_size: size of the embedding vectors\n",
    "    - num_layers: number of stacked LSTM cells\n",
    "    - rnn_size: number of hidden units in each LSTM cell\n",
    "    - target_sequence_length: 1-D tensor of target lengths, one entry per example\n",
    "    - max_target_sequence_length: upper bound on the number of decoding steps\n",
    "    - encoder_state: final encoder state, used as the decoder's initial state\n",
    "    - decoder_input: id tensor fed to the training decoder\n",
    "\n",
    "    Returns:\n",
    "    - (training_decoder_output, predicting_decoder_output): the first element\n",
    "      returned by dynamic_decode for the training and the inference decoder.\n",
    "    '''\n",
    "    # 1. Embedding\n",
    "    target_vocab_size = len(target_letter_to_int)\n",
    "    decoder_embeddings = tf.Variable(tf.random_uniform([target_vocab_size, decoding_embedding_size]))\n",
    "    decoder_embed_input = tf.nn.embedding_lookup(decoder_embeddings, decoder_input)\n",
    "\n",
    "    # 2. Stacked LSTM cells of the decoder\n",
    "    def get_decoder_cell(rnn_size):\n",
    "        decoder_cell = tf.contrib.rnn.LSTMCell(rnn_size,\n",
    "                                           initializer=tf.random_uniform_initializer(-0.1, 0.1, seed=2))\n",
    "        return decoder_cell\n",
    "    cell = tf.contrib.rnn.MultiRNNCell([get_decoder_cell(rnn_size) for _ in range(num_layers)])\n",
    "\n",
    "    # 3. Fully connected output layer with one unit per target vocabulary entry\n",
    "    output_layer = Dense(target_vocab_size,\n",
    "                         kernel_initializer = tf.truncated_normal_initializer(mean = 0.0, stddev=0.1))\n",
    "\n",
    "\n",
    "    # 4. Training decoder\n",
    "    with tf.variable_scope(\"decode\"):\n",
    "        # The helper feeds the embedded decoder_input at every step\n",
    "        training_helper = tf.contrib.seq2seq.TrainingHelper(inputs=decoder_embed_input,\n",
    "                                                            sequence_length=target_sequence_length,\n",
    "                                                            time_major=False)\n",
    "        training_decoder = tf.contrib.seq2seq.BasicDecoder(cell,\n",
    "                                                           training_helper,\n",
    "                                                           encoder_state,\n",
    "                                                           output_layer)\n",
    "        training_decoder_output, _ = tf.contrib.seq2seq.dynamic_decode(training_decoder,\n",
    "                                                                       impute_finished=True,\n",
    "                                                                       maximum_iterations=max_target_sequence_length)\n",
    "    # 5. Inference decoder\n",
    "    # Shares its variables with the training decoder\n",
    "    with tf.variable_scope(\"decode\", reuse=True):\n",
    "        # One <GO> id per example. The number of examples is read from the fed\n",
    "        # target_sequence_length vector, so this function does not depend on a\n",
    "        # notebook-level batch_size variable.\n",
    "        start_tokens = tf.tile(tf.constant([target_letter_to_int['<GO>']], dtype=tf.int32),\n",
    "                               tf.shape(target_sequence_length),\n",
    "                               name='start_tokens')\n",
    "        # The greedy helper embeds the previous step's sampled id as the next\n",
    "        # input and finishes a sequence once <EOS> is produced\n",
    "        predicting_helper = tf.contrib.seq2seq.GreedyEmbeddingHelper(decoder_embeddings,\n",
    "                                                                start_tokens,\n",
    "                                                                target_letter_to_int['<EOS>'])\n",
    "        predicting_decoder = tf.contrib.seq2seq.BasicDecoder(cell,\n",
    "                                                        predicting_helper,\n",
    "                                                        encoder_state,\n",
    "                                                        output_layer)\n",
    "        predicting_decoder_output, _ = tf.contrib.seq2seq.dynamic_decode(predicting_decoder,\n",
    "                                                            impute_finished=True,\n",
    "                                                            maximum_iterations=max_target_sequence_length)\n",
    "\n",
    "    return training_decoder_output, predicting_decoder_output"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Seq2Seq\n",
    "\n",
    "上面已经构建完成Encoder和Decoder，下面将这两部分连接起来，构建seq2seq模型"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def seq2seq_model(input_data, targets, lr, target_sequence_length,\n",
    "                  max_target_sequence_length, source_sequence_length,\n",
    "                  source_vocab_size, target_vocab_size,\n",
    "                  encoder_embedding_size, decoder_embedding_size,\n",
    "                  rnn_size, num_layers):\n",
    "    '''\n",
    "    Connect the encoder and the decoder into one seq2seq model.\n",
    "\n",
    "    Args:\n",
    "    - input_data: source id tensor\n",
    "    - targets: target id tensor\n",
    "    - lr: not used here; kept so existing calls stay valid\n",
    "    - target_sequence_length: 1-D tensor of target lengths\n",
    "    - max_target_sequence_length: upper bound on the number of decoding steps\n",
    "    - source_sequence_length: 1-D tensor of source lengths\n",
    "    - source_vocab_size: size of the source vocabulary\n",
    "    - target_vocab_size: not used here (decoding_layer derives the size from\n",
    "      target_letter_to_int); kept so existing calls stay valid\n",
    "    - encoder_embedding_size: embedding size used by the encoder\n",
    "    - decoder_embedding_size: embedding size used by the decoder\n",
    "    - rnn_size: number of hidden units in each LSTM cell\n",
    "    - num_layers: number of stacked LSTM cells\n",
    "\n",
    "    Also reads the notebook-level names target_letter_to_int and batch_size.\n",
    "\n",
    "    Returns:\n",
    "    - (training_decoder_output, predicting_decoder_output) from decoding_layer\n",
    "    '''\n",
    "    # Only the final encoder state is handed to the decoder.\n",
    "    # The embedding size comes from this function's own parameter rather than\n",
    "    # from a notebook-level variable with a similar name.\n",
    "    _, encoder_state = get_encoder_layer(input_data,\n",
    "                                  rnn_size,\n",
    "                                  num_layers,\n",
    "                                  source_sequence_length,\n",
    "                                  source_vocab_size,\n",
    "                                  encoder_embedding_size)\n",
    "\n",
    "\n",
    "    # Decoder input: targets without their last column, prefixed with <GO>\n",
    "    decoder_input = process_decoder_input(targets, target_letter_to_int, batch_size)\n",
    "\n",
    "    # Hand the encoder state and the decoder input to the decoder\n",
    "    training_decoder_output, predicting_decoder_output = decoding_layer(target_letter_to_int,\n",
    "                                                                       decoder_embedding_size,\n",
    "                                                                       num_layers,\n",
    "                                                                       rnn_size,\n",
    "                                                                       target_sequence_length,\n",
    "                                                                       max_target_sequence_length,\n",
    "                                                                       encoder_state,\n",
    "                                                                       decoder_input)\n",
    "\n",
    "    return training_decoder_output, predicting_decoder_output\n",
    "    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 超参数\n",
    "# Number of Epochs\n",
    "epochs = 60\n",
    "# Batch Size\n",
    "batch_size = 128\n",
    "# RNN Size\n",
    "rnn_size = 50\n",
    "# Number of Layers\n",
    "num_layers = 2\n",
    "# Embedding Size\n",
    "encoding_embedding_size = 15\n",
    "decoding_embedding_size = 15\n",
    "# Learning Rate\n",
    "learning_rate = 0.001"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# 构造graph\n",
    "train_graph = tf.Graph()\n",
    "\n",
    "with train_graph.as_default():\n",
    "    \n",
    "    # Create the model inputs\n",
    "    input_data, targets, lr, target_sequence_length, max_target_sequence_length, source_sequence_length = get_inputs()\n",
    "    \n",
    "    training_decoder_output, predicting_decoder_output = seq2seq_model(input_data, \n",
    "                                                                      targets, \n",
    "                                                                      lr, \n",
    "                                                                      target_sequence_length, \n",
    "                                                                      max_target_sequence_length, \n",
    "                                                                      source_sequence_length,\n",
    "                                                                      len(source_letter_to_int),\n",
    "                                                                      len(target_letter_to_int),\n",
    "                                                                      encoding_embedding_size, \n",
    "                                                                      decoding_embedding_size, \n",
    "                                                                      rnn_size, \n",
    "                                                                      num_layers)    \n",
    "    \n",
    "    # Give the outputs fixed names; the prediction cell looks up 'predictions:0' by name\n",
    "    training_logits = tf.identity(training_decoder_output.rnn_output, 'logits')\n",
    "    predicting_logits = tf.identity(predicting_decoder_output.sample_id, name='predictions')\n",
    "    \n",
    "    # Per-position loss weights built from the target lengths\n",
    "    masks = tf.sequence_mask(target_sequence_length, max_target_sequence_length, dtype=tf.float32, name='masks')\n",
    "\n",
    "    with tf.name_scope(\"optimization\"):\n",
    "        \n",
    "        # Loss function\n",
    "        cost = tf.contrib.seq2seq.sequence_loss(\n",
    "            training_logits,\n",
    "            targets,\n",
    "            masks)\n",
    "\n",
    "        # Optimizer\n",
    "        optimizer = tf.train.AdamOptimizer(lr)\n",
    "\n",
    "        # Gradient clipping: clamp every gradient element to [-5, 5] to guard against exploding gradients\n",
    "        gradients = optimizer.compute_gradients(cost)\n",
    "        capped_gradients = [(tf.clip_by_value(grad, -5., 5.), var) for grad, var in gradients if grad is not None]\n",
    "        train_op = optimizer.apply_gradients(capped_gradients)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Batches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def pad_sentence_batch(sentence_batch, pad_int):\n",
    "    '''\n",
    "    对batch中的序列进行补全，保证batch中的每行都有相同的sequence_length\n",
    "    \n",
    "    参数：\n",
    "    - sentence batch\n",
    "    - pad_int: <PAD>对应索引号\n",
    "    '''\n",
    "    max_sentence = max([len(sentence) for sentence in sentence_batch])\n",
    "    return [sentence + [pad_int] * (max_sentence - len(sentence)) for sentence in sentence_batch]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def get_batches(targets, sources, batch_size, source_pad_int, target_pad_int):\n",
    "    '''\n",
    "    Generator that yields padded batches of batch_size examples.\n",
    "\n",
    "    Only full batches are produced: examples left over after the last full\n",
    "    batch of sources are not yielded.\n",
    "\n",
    "    Yields:\n",
    "    - (pad_targets_batch, pad_sources_batch, pad_targets_lengths, pad_source_lengths)\n",
    "    '''\n",
    "    full_batch_count = len(sources) // batch_size\n",
    "    for start_i in range(0, full_batch_count * batch_size, batch_size):\n",
    "        stop_i = start_i + batch_size\n",
    "\n",
    "        # Pad each batch to the length of its longest sequence\n",
    "        pad_sources_batch = np.array(pad_sentence_batch(sources[start_i:stop_i], source_pad_int))\n",
    "        pad_targets_batch = np.array(pad_sentence_batch(targets[start_i:stop_i], target_pad_int))\n",
    "\n",
    "        # NOTE(review): the lengths are measured on the padded rows, so every\n",
    "        # entry equals the padded width of the batch, not the unpadded length.\n",
    "        pad_targets_lengths = [len(target) for target in pad_targets_batch]\n",
    "        pad_source_lengths = [len(source) for source in pad_sources_batch]\n",
    "\n",
    "        yield pad_targets_batch, pad_sources_batch, pad_targets_lengths, pad_source_lengths"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Train"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch   1/60 Batch   50/77 - Training Loss:  2.332  - Validation loss:  2.091\n",
      "Epoch   2/60 Batch   50/77 - Training Loss:  1.803  - Validation loss:  1.593\n",
      "Epoch   3/60 Batch   50/77 - Training Loss:  1.550  - Validation loss:  1.379\n",
      "Epoch   4/60 Batch   50/77 - Training Loss:  1.343  - Validation loss:  1.184\n",
      "Epoch   5/60 Batch   50/77 - Training Loss:  1.230  - Validation loss:  1.077\n",
      "Epoch   6/60 Batch   50/77 - Training Loss:  1.096  - Validation loss:  0.956\n",
      "Epoch   7/60 Batch   50/77 - Training Loss:  0.993  - Validation loss:  0.849\n",
      "Epoch   8/60 Batch   50/77 - Training Loss:  0.893  - Validation loss:  0.763\n",
      "Epoch   9/60 Batch   50/77 - Training Loss:  0.808  - Validation loss:  0.673\n",
      "Epoch  10/60 Batch   50/77 - Training Loss:  0.728  - Validation loss:  0.600\n",
      "Epoch  11/60 Batch   50/77 - Training Loss:  0.650  - Validation loss:  0.539\n",
      "Epoch  12/60 Batch   50/77 - Training Loss:  0.594  - Validation loss:  0.494\n",
      "Epoch  13/60 Batch   50/77 - Training Loss:  0.560  - Validation loss:  0.455\n",
      "Epoch  14/60 Batch   50/77 - Training Loss:  0.502  - Validation loss:  0.411\n",
      "Epoch  15/60 Batch   50/77 - Training Loss:  0.464  - Validation loss:  0.380\n",
      "Epoch  16/60 Batch   50/77 - Training Loss:  0.428  - Validation loss:  0.352\n",
      "Epoch  17/60 Batch   50/77 - Training Loss:  0.394  - Validation loss:  0.323\n",
      "Epoch  18/60 Batch   50/77 - Training Loss:  0.364  - Validation loss:  0.297\n",
      "Epoch  19/60 Batch   50/77 - Training Loss:  0.335  - Validation loss:  0.270\n",
      "Epoch  20/60 Batch   50/77 - Training Loss:  0.305  - Validation loss:  0.243\n",
      "Epoch  21/60 Batch   50/77 - Training Loss:  0.311  - Validation loss:  0.248\n",
      "Epoch  22/60 Batch   50/77 - Training Loss:  0.253  - Validation loss:  0.203\n",
      "Epoch  23/60 Batch   50/77 - Training Loss:  0.227  - Validation loss:  0.182\n",
      "Epoch  24/60 Batch   50/77 - Training Loss:  0.204  - Validation loss:  0.165\n",
      "Epoch  25/60 Batch   50/77 - Training Loss:  0.184  - Validation loss:  0.150\n",
      "Epoch  26/60 Batch   50/77 - Training Loss:  0.166  - Validation loss:  0.136\n",
      "Epoch  27/60 Batch   50/77 - Training Loss:  0.150  - Validation loss:  0.124\n",
      "Epoch  28/60 Batch   50/77 - Training Loss:  0.135  - Validation loss:  0.113\n",
      "Epoch  29/60 Batch   50/77 - Training Loss:  0.121  - Validation loss:  0.103\n",
      "Epoch  30/60 Batch   50/77 - Training Loss:  0.109  - Validation loss:  0.094\n",
      "Epoch  31/60 Batch   50/77 - Training Loss:  0.098  - Validation loss:  0.086\n",
      "Epoch  32/60 Batch   50/77 - Training Loss:  0.088  - Validation loss:  0.079\n",
      "Epoch  33/60 Batch   50/77 - Training Loss:  0.079  - Validation loss:  0.073\n",
      "Epoch  34/60 Batch   50/77 - Training Loss:  0.071  - Validation loss:  0.067\n",
      "Epoch  35/60 Batch   50/77 - Training Loss:  0.063  - Validation loss:  0.062\n",
      "Epoch  36/60 Batch   50/77 - Training Loss:  0.057  - Validation loss:  0.057\n",
      "Epoch  37/60 Batch   50/77 - Training Loss:  0.052  - Validation loss:  0.053\n",
      "Epoch  38/60 Batch   50/77 - Training Loss:  0.047  - Validation loss:  0.049\n",
      "Epoch  39/60 Batch   50/77 - Training Loss:  0.043  - Validation loss:  0.045\n",
      "Epoch  40/60 Batch   50/77 - Training Loss:  0.039  - Validation loss:  0.042\n",
      "Epoch  41/60 Batch   50/77 - Training Loss:  0.036  - Validation loss:  0.039\n",
      "Epoch  42/60 Batch   50/77 - Training Loss:  0.033  - Validation loss:  0.037\n",
      "Epoch  43/60 Batch   50/77 - Training Loss:  0.030  - Validation loss:  0.034\n",
      "Epoch  44/60 Batch   50/77 - Training Loss:  0.028  - Validation loss:  0.032\n",
      "Epoch  45/60 Batch   50/77 - Training Loss:  0.026  - Validation loss:  0.029\n",
      "Epoch  46/60 Batch   50/77 - Training Loss:  0.024  - Validation loss:  0.028\n",
      "Epoch  47/60 Batch   50/77 - Training Loss:  0.027  - Validation loss:  0.029\n",
      "Epoch  48/60 Batch   50/77 - Training Loss:  0.030  - Validation loss:  0.030\n",
      "Epoch  49/60 Batch   50/77 - Training Loss:  0.023  - Validation loss:  0.026\n",
      "Epoch  50/60 Batch   50/77 - Training Loss:  0.021  - Validation loss:  0.024\n",
      "Epoch  51/60 Batch   50/77 - Training Loss:  0.019  - Validation loss:  0.022\n",
      "Epoch  52/60 Batch   50/77 - Training Loss:  0.017  - Validation loss:  0.021\n",
      "Epoch  53/60 Batch   50/77 - Training Loss:  0.016  - Validation loss:  0.020\n",
      "Epoch  54/60 Batch   50/77 - Training Loss:  0.015  - Validation loss:  0.019\n",
      "Epoch  55/60 Batch   50/77 - Training Loss:  0.014  - Validation loss:  0.018\n",
      "Epoch  56/60 Batch   50/77 - Training Loss:  0.013  - Validation loss:  0.018\n",
      "Epoch  57/60 Batch   50/77 - Training Loss:  0.012  - Validation loss:  0.017\n",
      "Epoch  58/60 Batch   50/77 - Training Loss:  0.011  - Validation loss:  0.016\n",
      "Epoch  59/60 Batch   50/77 - Training Loss:  0.011  - Validation loss:  0.016\n",
      "Epoch  60/60 Batch   50/77 - Training Loss:  0.010  - Validation loss:  0.015\n",
      "Model Trained and Saved\n"
     ]
    }
   ],
   "source": [
    "# Split the data into a training set and a validation set\n",
    "train_source = source_int[batch_size:]\n",
    "train_target = target_int[batch_size:]\n",
    "# Hold out the first batch_size examples as a single validation batch\n",
    "valid_source = source_int[:batch_size]\n",
    "valid_target = target_int[:batch_size]\n",
    "(valid_targets_batch, valid_sources_batch, valid_targets_lengths, valid_sources_lengths) = next(get_batches(valid_target, valid_source, batch_size,\n",
    "                           source_letter_to_int['<PAD>'],\n",
    "                           target_letter_to_int['<PAD>']))\n",
    "\n",
    "display_step = 50 # print the losses every 50 batches\n",
    "\n",
    "checkpoint = \"trained_model.ckpt\" \n",
    "with tf.Session(graph=train_graph) as sess:\n",
    "    sess.run(tf.global_variables_initializer())\n",
    "        \n",
    "    for epoch_i in range(1, epochs+1):\n",
    "        for batch_i, (targets_batch, sources_batch, targets_lengths, sources_lengths) in enumerate(\n",
    "                get_batches(train_target, train_source, batch_size,\n",
    "                           source_letter_to_int['<PAD>'],\n",
    "                           target_letter_to_int['<PAD>'])):\n",
    "            \n",
    "            _, loss = sess.run(\n",
    "                [train_op, cost],\n",
    "                {input_data: sources_batch,\n",
    "                 targets: targets_batch,\n",
    "                 lr: learning_rate,\n",
    "                 target_sequence_length: targets_lengths,\n",
    "                 source_sequence_length: sources_lengths})\n",
    "\n",
    "            if batch_i % display_step == 0:\n",
    "                \n",
    "                # Compute the loss on the held-out validation batch (train_op is not run here)\n",
    "                validation_loss = sess.run(\n",
    "                [cost],\n",
    "                {input_data: valid_sources_batch,\n",
    "                 targets: valid_targets_batch,\n",
    "                 lr: learning_rate,\n",
    "                 target_sequence_length: valid_targets_lengths,\n",
    "                 source_sequence_length: valid_sources_lengths})\n",
    "                \n",
    "                print('Epoch {:>3}/{} Batch {:>4}/{} - Training Loss: {:>6.3f}  - Validation loss: {:>6.3f}'\n",
    "                      .format(epoch_i,\n",
    "                              epochs, \n",
    "                              batch_i, \n",
    "                              len(train_source) // batch_size, \n",
    "                              loss, \n",
    "                              validation_loss[0]))\n",
    "\n",
    "    \n",
    "    \n",
    "    # Save the trained model to the checkpoint path\n",
    "    saver = tf.train.Saver()\n",
    "    saver.save(sess, checkpoint)\n",
    "    print('Model Trained and Saved')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 预测"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "def source_to_seq(text, sequence_length=7):\n",
    "    '''\n",
    "    Convert a word into a list of source letter ids, right-padded with <PAD>.\n",
    "\n",
    "    Args:\n",
    "    - text: the word to convert; letters missing from source_letter_to_int map to <UNK>\n",
    "    - sequence_length: length to pad to (default 7). Longer words are not\n",
    "      truncated; they are returned without padding.\n",
    "\n",
    "    Returns:\n",
    "    - list of ids with at least len(text) entries\n",
    "    '''\n",
    "    unk_id = source_letter_to_int['<UNK>']\n",
    "    pad_id = source_letter_to_int['<PAD>']\n",
    "\n",
    "    letter_ids = [source_letter_to_int.get(letter, unk_id) for letter in text]\n",
    "    # A non-positive multiplier yields an empty list, so long words get no padding\n",
    "    return letter_ids + [pad_id] * (sequence_length - len(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {
    "collapsed": false,
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Restoring parameters from ./trained_model.ckpt\n",
      "原始输入: common\n",
      "\n",
      "Source\n",
      "  Word 编号:    [20, 28, 6, 6, 28, 5, 0]\n",
      "  Input Words: c o m m o n <PAD>\n",
      "\n",
      "Target\n",
      "  Word 编号:       [20, 6, 6, 5, 28, 28, 3]\n",
      "  Response Words: c m m n o o <EOS>\n"
     ]
    }
   ],
   "source": [
    "# Word to sort\n",
    "input_word = 'common'\n",
    "text = source_to_seq(input_word)\n",
    "\n",
    "checkpoint = \"./trained_model.ckpt\"\n",
    "\n",
    "loaded_graph = tf.Graph()\n",
    "with tf.Session(graph=loaded_graph) as sess:\n",
    "    # Restore the saved graph definition and the trained weights\n",
    "    loader = tf.train.import_meta_graph(checkpoint + '.meta')\n",
    "    loader.restore(sess, checkpoint)\n",
    "\n",
    "    # The tensors of the restored graph get their own loaded_* names. Reusing\n",
    "    # input_data / source_sequence_length / target_sequence_length here would\n",
    "    # overwrite the training-graph tensors that the training cell feeds, and\n",
    "    # re-running that cell afterwards would fail.\n",
    "    loaded_input_data = loaded_graph.get_tensor_by_name('inputs:0')\n",
    "    loaded_predictions = loaded_graph.get_tensor_by_name('predictions:0')\n",
    "    loaded_source_sequence_length = loaded_graph.get_tensor_by_name('source_sequence_length:0')\n",
    "    loaded_target_sequence_length = loaded_graph.get_tensor_by_name('target_sequence_length:0')\n",
    "\n",
    "    # Feed batch_size copies of the word and keep the first row of predicted ids\n",
    "    answer_ids = sess.run(loaded_predictions, {loaded_input_data: [text]*batch_size,\n",
    "                                               loaded_target_sequence_length: [len(text)]*batch_size,\n",
    "                                               loaded_source_sequence_length: [len(text)]*batch_size})[0]\n",
    "\n",
    "\n",
    "# The predicted ids belong to the target vocabulary, so filter with its <PAD> id\n",
    "target_pad = target_letter_to_int[\"<PAD>\"]\n",
    "\n",
    "print('原始输入:', input_word)\n",
    "\n",
    "print('\\nSource')\n",
    "print('  Word 编号:    {}'.format([i for i in text]))\n",
    "print('  Input Words: {}'.format(\" \".join([source_int_to_letter[i] for i in text])))\n",
    "\n",
    "print('\\nTarget')\n",
    "print('  Word 编号:       {}'.format([i for i in answer_ids if i != target_pad]))\n",
    "print('  Response Words: {}'.format(\" \".join([target_int_to_letter[i] for i in answer_ids if i != target_pad])))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "anaconda-cloud": {},
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.2"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
